import cdtools
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import (center_probe, center_probe_fourier,
                              split_dataset, fourier_pad)

# When True, diagnostic plots are drawn as the calculation proceeds.
# matplotlib is only imported if plotting was requested.
view = True
if view:
    import matplotlib.pyplot as plt

# The device to run the reconstruction on
device = 'cuda'

# Read the diffraction data from the .cxi file
data_filename = 'data/single_shot_ptycho.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(
    data_filename, cut_zeros=False)

# Shift every pattern up by 10 and clip at zero. The offset lets the
# algorithm find the low-intensity flat-field, which otherwise goes negative
dataset.patterns = (dataset.patterns + 10).clamp(min=0)

# For superresolution, pad the patterns and the mask by to_pad pixels on
# each side of their last two dimensions. The new mask pixels are filled
# with 0.
to_pad = 200
dataset.patterns = t.nn.functional.pad(dataset.patterns, (to_pad,) * 4)
dataset.mask = t.nn.functional.pad(dataset.mask, (to_pad,) * 4, value=0)

# Split the data into two disjoint halves so that an FRC can be calculated,
# and keep a separate name for the complete dataset because the
# reconstruction loop below reuses the name `dataset`
dataset_1, dataset_2 = split_dataset(dataset)
dataset_full = dataset

def _drop_patterns(dataset, dropped):
    """Remove the patterns at the indices in `dropped` from `dataset`.

    The patterns, translations and intensities of the dataset are each
    replaced by the concatenation of the runs lying between consecutive
    dropped indices. `dropped` is expected in increasing order, as the
    hard-coded lists below are. An empty list keeps every entry.
    """
    # Sentinels: -1 makes the first run start at index 0, and None makes the
    # last run extend to the end of the tensor
    edges = [-1, *dropped, None]
    kept = [slice(lo + 1, hi) for lo, hi in zip(edges[:-1], edges[1:])]

    dataset.patterns = t.cat([dataset.patterns[s] for s in kept], dim=0)
    dataset.translations = t.cat(
        [dataset.translations[s] for s in kept], dim=0)
    dataset.intensities = t.cat(
        [dataset.intensities[s] for s in kept], dim=0)


def _run_adam(model, dataset, n_epochs, *, lr, batch_size, schedule):
    """Run one round of Adam optimization, reporting after every epoch.

    The model's report is printed each epoch. If plotting is enabled
    (module-level `view`), the model is also inspected whenever the model's
    epoch counter is a multiple of 5.
    """
    for loss in model.Adam_optimize(n_epochs, dataset, lr=lr,
                                    batch_size=batch_size,
                                    schedule=schedule):
        print(model.report())
        if view and model.epoch % 5 == 0:
            model.inspect(dataset)


# Each entry is (plan name, dataset, indices of the patterns to exclude).
# The excluded patterns seem to cause extra problems with the second half
# reconstruction; the same patterns are excluded from the full
# reconstruction, where their indices are different. Nothing is excluded
# from the first half.
reconstruction_plans = [
    ('half_2', dataset_2, [11, 58, 126, 128, 129]),
    ('half_1', dataset_1, []),
    ('full', dataset_full, [21, 112, 249, 253, 255]),
]

for plan, dataset, dropped in reconstruction_plans:

    initialization_file = f'results/single_shot_ptycho_{plan}_rough.mat'
    savefile = f'results/single_shot_ptycho_{plan}_polished_{to_pad}pad.mat'
    print('To save in', savefile)

    # Remove the few key patterns which caused issues from this dataset
    _drop_patterns(dataset, dropped)

    if view:
        dataset.inspect()

    # The rough reconstruction provides the starting probe and background
    rough_results = loadmat(initialization_file)

    # Next, we create a ptychography model from the dataset. Using
    # probe_fourier_crop keeps the probe to the original resolution,
    # even though the object's resolution is increased
    probe_fourier_crop = None if to_pad == 0 else to_pad
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        # The propagation distance is needed to set the probe_norm well
        propagation_distance=30e-6,
        translation_scale=10,
        probe_fourier_crop=probe_fourier_crop,
        probe_support_radius=300,
        simulate_probe_translation=False,
        obj_view_crop=-200,
        units='um',
    )

    # Initialize the probe from the rough reconstruction: restrict it to the
    # probe support, center it, normalize it, and apply the support again
    rough_probe = t.as_tensor(rough_results['probe'])
    model.probe.data = (center_probe(model.probe_support * rough_probe)
                        / model.probe_norm)
    model.probe.data *= model.probe_support

    # Initialize the background from the rough reconstruction as well,
    # clipped at zero and padded to match the padded patterns.
    # NOTE(review): the square root is applied twice here; confirm that this
    # matches how the rough reconstruction stored its background
    bkg = t.as_tensor(rough_results['background']).clamp(min=0)
    model.background.data = t.nn.functional.pad(
        t.sqrt(t.sqrt(bkg)), (to_pad,) * 4)

    # Move the model and the data to the correct device
    model.to(device=device)
    dataset.get_as(device=device)

    if view:
        model.inspect()

    # We start with two rounds of aggressive reconstructions, resetting the
    # object to 1 after each round, with the goal of recovering good probe
    # positions before starting the final reconstruction. Each round has a
    # short stage with the probe held fixed, followed by a longer stage
    # which also refines the probe: (refine probe?, epochs, learning rate)
    warmup_stages = [(False, 10, 0.03), (True, 50, 0.01)]
    for _ in range(2):
        for refine_probe, n_epochs, lr in warmup_stages:
            model.probe.requires_grad = refine_probe
            _run_adam(model, dataset, n_epochs,
                      lr=lr, batch_size=25, schedule=False)

        model.obj.data[:] = 1

    # Now we do a final round of object initialization without probe
    # refinement
    model.probe.requires_grad = False
    _run_adam(model, dataset, 10, lr=0.03, batch_size=25, schedule=False)

    # And we finish the reconstruction with a final, long round including
    # learning rate scheduling to really converge well
    model.probe.requires_grad = True
    _run_adam(model, dataset, 500, lr=0.03, batch_size=200, schedule=True)

    # We orthogonalize the resulting probes
    model.tidy_probes()

    # And save the results
    savemat(savefile, model.save_results(dataset))

    # Finally, we plot the results
    if view:
        model.inspect(dataset)
        model.compare(dataset)

# Show the figures once all reconstructions have finished
if view:
    plt.show()
